function [X,Y,D,n,ps,Yscale,x1_wave,prevuse_wave,Xscale] = generate_inputs_cubic_rule(sample,wave,tcost,covariate,SMC)
% GENERATE_INPUTS_CUBIC_RULE Build the inputs for the EWM cubic rule search.
%
% Input
% (1) sample: table of cross-sectional data (merged using
%     elec_wave367_cross.csv plus propensity scores, i.e.
%     data_estimation_pooled_ps). Variables read: post_avg, pre_avg,
%     opower, ps; opower_paper_wave when wave~=0; and the one covariate
%     variable that is requested (income, uni_size, vintage, min_baseline,
%     max_baseline or std_baseline).
% (2) wave: 3, 6 or 7 for a wave-specific analysis (only rows with
%     opower_paper_wave==wave are used), or 0 for the pooled sample (all rows).
% (3) tcost: program cost per household. Nonzero: Y is the priced change in
%     consumption plus tcost for treated households (cost-savings
%     analysis). Zero: Y is the raw change in consumption (kWh-reduction
%     analysis).
% (4) covariate: "income", "size", "vintage", "min", "max" or "std"
%     (min / max / standard deviation of baseline consumption). Any other
%     value raises an error.
% (5) SMC: nonzero -> consumption priced at 0.065 (social marginal cost);
%     zero -> priced at 0.177 (retail electricity price). Only matters when
%     tcost is nonzero. NOTE(review): units presumably $/kWh -- confirm.
%
% Output
% (1) X: n-by-5 design matrix [1, c, c.^2, c.^3, prevuse], where c is the
%     selected covariate (income as is; the others replaced by the lower
%     edge of one of 30 equal-width bins) and prevuse is binned baseline
%     consumption. Every column except the constant is divided by its
%     maximum absolute value.
% (2) Y: outcome (post_avg - pre_avg, priced and net of program cost when
%     tcost~=0), demeaned by Y_demean and divided by Yscale, so Y is in [-1,1].
% (3) D: actual Opower treatment status in the RCT (opower).
% (4) n: number of households (length of Y).
% (5) ps: propensity scores.
% (6) Yscale: max(abs(demeaned outcome)), the factor used to normalize Y.
% (7) x1_wave: the selected covariate, unscaled and undiscretized.
% (8) prevuse_wave: baseline consumption (pre_avg) rounded to integers,
%     unscaled and undiscretized.
% (9) Xscale: n-by-4 matrix; every row holds the column maxima of
%     abs([c, c.^2, c.^3, prevuse]) used to normalize X.

%% Select the analysis rows once
if wave == 0
    col = @(name) sample.(name);                 % pooled sample: every row
else
    in_wave = sample.opower_paper_wave == wave;  % wave mask, computed once
    col = @(name) sample.(name)(in_wave,:);
end

usage_wave   = col('post_avg') - col('pre_avg'); % change in consumption
opower_wave  = col('opower');
prevuse_wave = round(col('pre_avg'));
ps           = col('ps');

%% Y in terms of cost savings (tcost ~= 0) or energy conservation (tcost == 0)
if SMC
    price = 0.065;   % social marginal cost
else
    price = 0.177;   % retail electricity price
end
if tcost
    Y_all = price.*usage_wave(:,1) + tcost.*opower_wave(:,1);
else
    Y_all = usage_wave(:,1);   % kWh change; price is not applied
end

% Pooled sample: demean within wave (flag = 1). Wave-specific analysis:
% always demean by the "full" wave sample (flag = 0) -- DO NOT CHANGE.
demean_wave = double(wave == 0);
Y_all = Y_demean(Y_all, sample, demean_wave);

%% Define Y, Yscale, D and n as final inputs
Yscale = max(abs(Y_all));
Y = Y_all ./ Yscale;   % rescale demeaned outcomes to [-1,1]
D = opower_wave;
n = length(Y);         % sample size

%% Define X and Xscale (cubic in the covariate plus binned baseline use)
nbins = 30;   % number of equal-width bins for discretized variables
prevuse_dcrt = lower_edge_discretize(prevuse_wave, nbins);

switch covariate
    case 'income'
        x1_wave = col('income');
        c = x1_wave;   % income enters undiscretized
    case 'size'
        x1_wave = col('uni_size');
        c = lower_edge_discretize(x1_wave, nbins);
    case 'vintage'
        x1_wave = col('vintage');
        c = lower_edge_discretize(x1_wave, nbins);
    case "min"
        x1_wave = col('min_baseline');
        c = lower_edge_discretize(x1_wave(:,1), nbins);
    case "max"
        x1_wave = col('max_baseline');
        c = lower_edge_discretize(x1_wave(:,1), nbins);
    case "std"
        x1_wave = col('std_baseline');
        c = lower_edge_discretize(x1_wave(:,1), nbins);
    otherwise
        error('generate_inputs_cubic_rule:unknownCovariate', ...
            'Unknown covariate "%s"; expected income, size, vintage, min, max or std.', ...
            covariate);
end
covariates = [c c.^2 c.^3 prevuse_dcrt];

Xscale = ones(n,1)*max(abs(covariates));   % column maxima repeated on every row
X = [ones(n,1) covariates./Xscale];

end


function xd = lower_edge_discretize(x, nbins)
% LOWER_EDGE_DISCRETIZE Replace each value by the lower edge of its bin.
% Bins are the nbins equal-width bins chosen by discretize(x, nbins). x is
% expected to be a column vector; the result is a column vector of the same
% length. A value that cannot be binned (e.g. NaN) raises a descriptive
% error instead of an opaque indexing failure.
[bin, edges] = discretize(x, nbins);
bin = bin(:);
if any(isnan(bin))
    error('generate_inputs_cubic_rule:undiscretizable', ...
        'Cannot discretize: input contains NaN (or values outside the bin edges).');
end
xd = reshape(edges(bin), [], 1);
end